#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Sources and sinks.

A Source manages record-oriented data input from a particular kind of source
(e.g. a set of files, a database table, etc.). The read() method of a
BoundedSource returns an iterator over the records that fall within the range
described by a given RangeTracker.

A Sink manages record-oriented data output to a particular kind of sink
(e.g. a set of files, a database table, etc.). The open_writer() method of a
sink returns a Writer object supporting writing records of data to the sink.
"""

from __future__ import absolute_import
from __future__ import division

import logging
import math
import random
import uuid
from builtins import object
from builtins import range
from collections import namedtuple

from apache_beam import coders
from apache_beam import pvalue
from apache_beam.portability import common_urns
from apache_beam.portability import python_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import AsIter
from apache_beam.pvalue import AsSingleton
from apache_beam.transforms import core
from apache_beam.transforms import ptransform
from apache_beam.transforms import window
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.transforms.display import HasDisplayData
from apache_beam.utils import timestamp
from apache_beam.utils import urns
from apache_beam.utils.windowed_value import WindowedValue

# Public API of this module (names are listed alphabetically).
__all__ = [
    'BoundedSource',
    'RangeTracker',
    'Read',
    'RestrictionTracker',
    'Sink',
    'Write',
    'Writer',
]


# Describes one bundle of a source, as produced by ``BoundedSource.split()``.
# SourceBundle is a named 4-tuple with the following fields:
#
# * weight - a number representing the size of the bundle. It is used to
#            compare the relative sizes of the bundles generated by the same
#            source. Any unit may be chosen (for example, bundles of sizes
#            100MB, 200MB, and 700MB may specify weights 100, 200, 700 or
#            1, 2, 7), but all bundles of a source must use the same unit.
# * source - a BoundedSource object for the bundle.
# * start_position - starting position of the bundle.
# * stop_position - ending position of the bundle.
#
# The type of the start and stop positions is specific to the bounded source
# and must be consistent throughout.
SourceBundle = namedtuple(
    'SourceBundle',
    ['weight', 'source', 'start_position', 'stop_position'])


class SourceBase(HasDisplayData, urns.RunnerApiFn):
  """Common base class for every source accepted by ``beam.io.Read(...)``.

  Evaluating this class body registers ``python_urns.PICKLED_SOURCE`` through
  ``urns.RunnerApiFn.register_pickle_urn``; presumably this lets sources be
  represented in the Runner API as pickled payloads (confirm in ``urns``).
  """
  # Runs exactly once, at class-definition time.
  urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_SOURCE)


class BoundedSource(SourceBase):
  """A source that reads a finite amount of input records.

  This class defines the following operations, which can be used to read the
  source efficiently.

  * Size estimation - method ``estimate_size()`` may return an accurate
    estimation in bytes for the size of the source.
  * Splitting into bundles of a given size - method ``split()`` can be used to
    split the source into a set of sub-sources (bundles) based on a desired
    bundle size.
  * Getting a RangeTracker - method ``get_range_tracker()`` should return a
    ``RangeTracker`` object for a given position range for the position type
    of the records returned by the source.
  * Reading the data - method ``read()`` can be used to read data from the
    source while respecting the boundaries defined by a given
    ``RangeTracker``.

  A runner reads the source in two steps.

  (1) Method ``get_range_tracker()`` will be invoked with start and end
      positions to obtain a ``RangeTracker`` for the range of positions the
      runner intends to read. The source must define a default initial start
      and end position range. These positions must be used if the start and/or
      end positions passed to the method ``get_range_tracker()`` are ``None``.
  (2) Method ``read()`` will be invoked with the ``RangeTracker`` obtained in
      the previous step.

  **Mutability**

  A ``BoundedSource`` object should not be mutated while
  its methods (for example, ``read()``) are being invoked by a runner. Runner
  implementations may invoke methods of ``BoundedSource`` objects through
  multi-threaded and/or reentrant execution modes.
  """

  def estimate_size(self):
    """Estimates the size of source in bytes.

    An estimate of the total size (in bytes) of the data that would be read
    from this source. This estimate is in terms of external storage size,
    before performing decompression or other processing.

    Returns:
      estimated size of the source if the size can be determined, ``None``
      otherwise.

    Raises:
      NotImplementedError: always, in this base class; subclasses must
        override this method.
    """
    raise NotImplementedError

  def split(self, desired_bundle_size, start_position=None, stop_position=None):
    """Splits the source into a set of bundles.

    Bundles should be approximately of size ``desired_bundle_size`` bytes.

    Args:
      desired_bundle_size: the desired size (in bytes) of the bundles returned.
      start_position: if specified the given position must be used as the
                      starting position of the first bundle.
      stop_position: if specified the given position must be used as the ending
                     position of the last bundle.
    Returns:
      an iterator of ``SourceBundle`` objects that give information about the
      generated bundles.

    Raises:
      NotImplementedError: always, in this base class; subclasses must
        override this method.
    """
    raise NotImplementedError

  def get_range_tracker(self, start_position, stop_position):
    """Returns a RangeTracker for a given position range.

    The framework may invoke the ``read()`` method with the ``RangeTracker``
    object returned here to read data from the source.

    Args:
      start_position: starting position of the range. If ``None``, the default
                      start position of the source must be used.
      stop_position:  ending position of the range. If ``None``, the default
                      stop position of the source must be used.
    Returns:
      a ``RangeTracker`` for the given position range.

    Raises:
      NotImplementedError: always, in this base class; subclasses must
        override this method.
    """
    raise NotImplementedError

  def read(self, range_tracker):
    """Returns an iterator that reads data from the source.

    The returned set of data must respect the boundaries defined by the given
    ``RangeTracker`` object. For example:

      * Returned set of data must be for the range
        ``[range_tracker.start_position(), range_tracker.stop_position())``.
        Note that a source may decide to return records that start after
        ``range_tracker.stop_position()``. See documentation in class
        ``RangeTracker`` for more details. Also, note that the framework might
        invoke ``range_tracker.try_split()`` to perform dynamic split
        operations. ``range_tracker.stop_position()`` may be updated
        dynamically due to successful dynamic split operations.
      * Method ``range_tracker.try_claim()`` must be invoked for every record
        that starts at a split point.
      * Method ``range_tracker.set_current_position()`` may be invoked for
        records that do not start at split points.

    Args:
      range_tracker: a ``RangeTracker`` whose boundaries must be respected
                     when reading data from the source. A runner that reads this
                     source must pass a ``RangeTracker`` object that is not
                     ``None``.
    Returns:
      an iterator of data read by the source.

    Raises:
      NotImplementedError: always, in this base class; subclasses must
        override this method.
    """
    raise NotImplementedError

  def default_output_coder(self):
    """Coder that should be used for the records returned by the source.

    Returns the coder registered for ``object`` in ``coders.registry``.
    Should be overridden by sources that produce objects that can be encoded
    more efficiently than pickling.
    """
    return coders.registry.get_coder(object)

  def is_bounded(self):
    """Returns ``True``: a ``BoundedSource`` always reads a finite input."""
    return True


class RangeTracker(object):
  """A thread-safe object used by the Dataflow source framework.

  A Dataflow source is defined using a ``BoundedSource`` and a ``RangeTracker``
  pair. A ``RangeTracker`` is used by the Dataflow source framework to perform
  dynamic work rebalancing of position-based sources.

  **Position-based sources**

  A position-based source is one where the source can be described by a range
  of positions of an ordered type and the records returned by the reader can be
  described by positions of the same type.

  In case a record occupies a range of positions in the source, the most
  important thing about the record is the position where it starts.

  Defining the semantics of positions for a source is entirely up to the source
  class, however the chosen definitions have to obey certain properties in order
  to make it possible to correctly split the source into parts, including
  dynamic splitting. Two main aspects need to be defined:

  1. How to assign starting positions to records.
  2. Which records should be read by a source with a range '[A, B)'.

  Moreover, reading a range must be *efficient*, i.e., the performance of
  reading a range should not significantly depend on the location of the range.
  For example, reading the range [A, B) should not require reading all data
  before 'A'.

  The sections below explain exactly what properties these definitions must
  satisfy, and how to use a ``RangeTracker`` with a properly defined source.

  **Properties of position-based sources**

  The main requirement for position-based sources is *associativity*: reading
  records from '[A, B)' and records from '[B, C)' should give the same
  records as reading from '[A, C)', where 'A <= B <= C'. This property
  ensures that no matter how a range of positions is split into arbitrarily many
  sub-ranges, the total set of records described by them stays the same.

  The other important property is how the source's range relates to positions of
  records in the source. In many sources each record can be identified by a
  unique starting position. In this case:

  * All records returned by a source '[A, B)' must have starting positions in
    this range.
  * All but the last record should end within this range. The last record may or
    may not extend past the end of the range.
  * Records should not overlap.

  Such sources should define "read '[A, B)'" as "read from the first record
  starting at or after 'A', up to but not including the first record starting
  at or after 'B'".

  Some examples of such sources include reading lines or CSV from a text file,
  reading keys and values from a BigTable, etc.

  The concept of *split points* makes it possible to extend the definitions to
  sources where some records cannot be identified by a unique starting
  position.

  In all cases, all records returned by a source '[A, B)' must *start* at or
  after 'A'.

  **Split points**

  Some sources may have records that are not directly addressable. For example,
  imagine a file format consisting of a sequence of compressed blocks. Each
  block can be assigned an offset, but records within the block cannot be
  directly addressed without decompressing the block. Let us refer to this
  hypothetical format as *CBF (Compressed Blocks Format)*.

  Many such formats can still satisfy the associativity property. For example,
  in CBF, reading '[A, B)' can mean "read all the records in all blocks whose
  starting offset is in '[A, B)'".

  To support such complex formats, we introduce the notion of *split points*. We
  say that a record is a split point if there exists a position 'A' such that
  the record is the first one to be returned when reading the range
  '[A, infinity)'. In CBF, the only split points would be the first records
  in each block.

  Split points allow us to define the meaning of a record's position and a
  source's range in all cases:

  * For a record that is at a split point, its position is defined to be the
    largest 'A' such that reading a source with the range '[A, infinity)'
    returns this record.
  * Positions of other records are only required to be non-decreasing.
  * Reading the source '[A, B)' must return records starting from the first
    split point at or after 'A', up to but not including the first split point
    at or after 'B'. In particular, this means that the first record returned
    by a source MUST always be a split point.
  * Positions of split points must be unique.

  As a result, for any decomposition of the full range of the source into
  position ranges, the total set of records will be the full set of records in
  the source, and each record will be read exactly once.

  **Consumed positions**

  As the source is being read, and records read from it are being passed to the
  downstream transforms in the pipeline, we say that positions in the source are
  being *consumed*. When a reader has read a record (or promised to a caller
  that a record will be returned), positions up to and including the record's
  start position are considered *consumed*.

  Dynamic splitting can happen only at *unconsumed* positions. If the reader
  just returned a record at offset 42 in a file, dynamic splitting can happen
  only at offset 43 or beyond, as otherwise that record could be read twice (by
  the current reader and by a reader of the task starting at 43).

  All abstract methods of this base class raise
  ``NotImplementedError(type(self))`` so the error names the concrete subclass
  that is missing the override.
  """

  # Sentinel reported by ``split_points()`` when a count is unknown; compared
  # by identity.
  SPLIT_POINTS_UNKNOWN = object()

  def start_position(self):
    """Returns the starting position of the current range, inclusive."""
    raise NotImplementedError(type(self))

  def stop_position(self):
    """Returns the ending position of the current range, exclusive."""
    raise NotImplementedError(type(self))

  def try_claim(self, position):  # pylint: disable=unused-argument
    """Atomically determines if a record at a split point is within the range.

    This method should be called **if and only if** the record is at a split
    point. This method may modify the internal state of the ``RangeTracker`` by
    updating the last-consumed position to ``position``.

    **Thread safety**

    Methods of the class ``RangeTracker`` including this method may get invoked
    by different threads, hence must be made thread-safe, e.g. by using a single
    lock object.

    Args:
      position: starting position of a record being read by a source.

    Returns:
      ``True`` if the given position falls within the current range, ``False``
      otherwise.
    """
    raise NotImplementedError(type(self))

  def set_current_position(self, position):
    """Updates the last-consumed position to the given position.

    A source may invoke this method for records that do not start at split
    points. This may modify the internal state of the ``RangeTracker``. If the
    record starts at a split point, method ``try_claim()`` **must** be invoked
    instead of this method.

    Args:
      position: starting position of a record being read by a source.
    """
    raise NotImplementedError(type(self))

  def position_at_fraction(self, fraction):
    """Returns the position at the given fraction.

    Given a fraction within the range [0.0, 1.0) this method will return the
    position at the given fraction compared to the position range
    [self.start_position, self.stop_position).

    **Thread safety**

    Methods of the class ``RangeTracker`` including this method may get invoked
    by different threads, hence must be made thread-safe, e.g. by using a single
    lock object.

    Args:
      fraction: a float value within the range [0.0, 1.0).
    Returns:
      a position within the range [self.start_position, self.stop_position).
    """
    raise NotImplementedError(type(self))

  def try_split(self, position):
    """Atomically splits the current range.

    Determines a position to split the current range, split_position, based on
    the given position. In most cases split_position and position will be the
    same.

    Splits the current range '[self.start_position, self.stop_position)'
    into a "primary" part '[self.start_position, split_position)' and a
    "residual" part '[split_position, self.stop_position)', assuming the
    current last-consumed position is within
    '[self.start_position, split_position)' (i.e., split_position has not been
    consumed yet).

    If successful, updates the current range to be the primary and returns a
    tuple (split_position, split_fraction). split_fraction should be the
    fraction of size of range '[self.start_position, split_position)' compared
    to the original (before split) range
    '[self.start_position, self.stop_position)'.

    If the split_position has already been consumed, returns ``None``.

    **Thread safety**

    Methods of the class ``RangeTracker`` including this method may get invoked
    by different threads, hence must be made thread-safe, e.g. by using a single
    lock object.

    Args:
      position: suggested position where the current range should try to
                be split at.
    Returns:
      a tuple containing the split position and split fraction if split is
      successful. Returns ``None`` otherwise.
    """
    raise NotImplementedError(type(self))

  def fraction_consumed(self):
    """Returns the approximate fraction of consumed positions in the source.

    **Thread safety**

    Methods of the class ``RangeTracker`` including this method may get invoked
    by different threads, hence must be made thread-safe, e.g. by using a single
    lock object.

    Returns:
      the approximate fraction of positions that have been consumed by
      successful ``try_split()`` and ``try_claim()`` calls, or
      0.0 if no such calls have happened.
    """
    raise NotImplementedError(type(self))

  def split_points(self):
    """Gives the number of split points consumed and remaining.

    For a ``RangeTracker`` used by a ``BoundedSource`` (within a
    ``BoundedSource.read()`` invocation) this method produces a 2-tuple that
    gives the number of split points consumed by the ``BoundedSource`` and the
    number of split points remaining within the range of the ``RangeTracker``
    that has not been consumed by the ``BoundedSource``.

    More specifically, given that the position of the current record being read
    by ``BoundedSource`` is current_position this method produces a tuple that
    consists of
    (1) number of split points in the range [self.start_position(),
    current_position) without including the split point that is currently being
    consumed. This represents the total amount of parallelism in the consumed
    part of the source.
    (2) number of split points within the range
    [current_position, self.stop_position()) including the split point that is
    currently being consumed. This represents the total amount of parallelism in
    the unconsumed part of the source.

    Methods of the class ``RangeTracker`` including this method may get invoked
    by different threads, hence must be made thread-safe, e.g. by using a single
    lock object.

    **General information about consumed and remaining split points**

    The following describes the consumed and remaining numbers of split points
    returned by this method.

      * Before a source read (``BoundedSource.read()`` invocation) claims the
        first split point, number of consumed split points is 0. This condition
        holds independent of whether the input is "splittable". A splittable
        source is a source that has more than one split point.
      * Any source read that has only claimed one split point has 0 consumed
        split points since the first split point is the current split point and
        is still being processed. This condition holds independent of whether
        the input is splittable.
      * For an empty source read which never invokes
        ``RangeTracker.try_claim()``, the consumed number of split points is 0.
        This condition holds independent of whether the input is splittable.
      * For a source read which has invoked ``RangeTracker.try_claim()`` n
        times, the consumed number of split points is n - 1.
      * If a ``BoundedSource`` sets a callback through function
        ``set_split_points_unclaimed_callback()``, ``RangeTracker`` can use that
        callback when determining remaining number of split points.
      * Remaining split points should include the split point that is currently
        being consumed by the source read. Hence if the above callback returns
        an integer value n, remaining number of split points should be (n + 1).
      * After the last split point is claimed, remaining split points becomes 1,
        because this unfinished read itself represents an unfinished split
        point.
      * After all records of the source have been consumed, remaining number of
        split points becomes 0 and consumed number of split points becomes equal
        to the total number of split points within the range being read by the
        source. This method does not address this condition and will continue to
        report number of consumed split points as
        ("total number of split points" - 1) and number of remaining split
        points as 1. A runner that performs the reading of the source can
        detect when all records have been consumed and adjust remaining and
        consumed number of split points accordingly.

    **Examples**

    (1) A "perfectly splittable" input which can be read in parallel down to the
        individual records.

        Consider a perfectly splittable input that consists of 50 split points.

      * Before a source read (``BoundedSource.read()`` invocation) claims the
        first split point, number of consumed split points is 0 and number of
        remaining split points is 50.
      * After claiming first split point, consumed number of split points is 0
        and remaining number of split points is 50.
      * After claiming split point #30, consumed number of split points is 29
        and remaining number of split points is 21.
      * After claiming all 50 split points, consumed number of split points is
        49 and remaining number of split points is 1.

    (2) a "block-compressed" file format such as ``avroio``, in which a block of
        records has to be read as a whole, but different blocks can be read in
        parallel.

        Consider a block compressed input that consists of 5 blocks.

      * Before a source read (``BoundedSource.read()`` invocation) claims the
        first split point (first block), number of consumed split points is 0
        and number of remaining split points is 5.
      * After claiming first split point, consumed number of split points is 0
        and remaining number of split points is 5.
      * After claiming split point #3, consumed number of split points is 2
        and remaining number of split points is 3.
      * After claiming all 5 split points, consumed number of split points is
        4 and remaining number of split points is 1.

    (3) an "unsplittable" input such as a cursor in a database or a gzip
        compressed file.

        Such an input is considered to have only a single split point. Number of
        consumed split points is always 0 and number of remaining split points
        is always 1.

    By default ``RangeTracker`` returns ``RangeTracker.SPLIT_POINTS_UNKNOWN`` for
    both consumed and remaining number of split points, which indicates that the
    number of split points consumed and remaining is unknown.

    Returns:
      A pair that gives consumed and remaining number of split points. Consumed
      number of split points should be an integer larger than or equal to zero
      or ``RangeTracker.SPLIT_POINTS_UNKNOWN``. Remaining number of split points
      should be an integer larger than zero or
      ``RangeTracker.SPLIT_POINTS_UNKNOWN``.
    """
    return (RangeTracker.SPLIT_POINTS_UNKNOWN,
            RangeTracker.SPLIT_POINTS_UNKNOWN)

  def set_split_points_unclaimed_callback(self, callback):
    """Sets a callback for determining the unclaimed number of split points.

    By invoking this function, a ``BoundedSource`` can set a callback function
    that may get invoked by the ``RangeTracker`` to determine the number of
    unclaimed split points. A split point is unclaimed if
    ``RangeTracker.try_claim()`` method has not been successfully invoked for
    that particular split point. The callback function accepts a single
    parameter, a stop position for the BoundedSource (stop_position). If the
    record currently being consumed by the ``BoundedSource`` is at position
    current_position, callback should return the number of split points within
    the range (current_position, stop_position). Note that this should not
    include the split point that is currently being consumed by the source.

    This function must be implemented by subclasses before being used.

    Args:
      callback: a function that takes a single parameter, a stop position,
                and returns unclaimed number of split points for the source read
                operation that is calling this function. Value returned from
                callback should be either an integer larger than or equal to
                zero or ``RangeTracker.SPLIT_POINTS_UNKNOWN``.
    """
    raise NotImplementedError(type(self))


class Sink(HasDisplayData):
  """This class is deprecated, no backwards-compatibility guarantees.

  A resource that can be written to using the ``beam.io.Write`` transform.

  Here ``beam`` stands for Apache Beam Python code imported in following manner.
  ``import apache_beam as beam``.

  A parallel write to an ``iobase.Sink`` consists of three phases:

  1. A sequential *initialization* phase (e.g., creating a temporary output
     directory, etc.)
  2. A parallel write phase where workers write *bundles* of records
  3. A sequential *finalization* phase (e.g., committing the writes, merging
     output files, etc.)

  Implementing a new sink requires extending two classes.

  1. iobase.Sink

  ``iobase.Sink`` is an immutable logical description of the location/resource
  to write to. Depending on the type of sink, it may contain fields such as the
  path to an output directory on a filesystem, a database table name,
  etc. ``iobase.Sink`` provides methods for performing a write operation to the
  sink described by it. To this end, implementors of an extension of
  ``iobase.Sink`` must implement three methods:
  ``initialize_write()``, ``open_writer()``, and ``finalize_write()``.

  2. iobase.Writer

  ``iobase.Writer`` is used to write a single bundle of records. An
  ``iobase.Writer`` defines two methods: ``write()`` which writes a
  single record from the bundle and ``close()`` which is called once
  at the end of writing a bundle.

  See also ``apache_beam.io.filebasedsink.FileBasedSink`` which provides a
  simpler API for writing sinks that produce files.

  **Execution of the Write transform**

  ``initialize_write()``, ``pre_finalize()``, and ``finalize_write()`` are
  conceptually called once. However, implementors must
  ensure that these methods are *idempotent*, as they may be called multiple
  times on different machines in the case of failure/retry. A method may be
  called more than once concurrently, in which case it's okay to have a
  transient failure (such as due to a race condition). This failure should not
  prevent subsequent retries from succeeding.

  ``initialize_write()`` should perform any initialization that needs to be done
  prior to writing to the sink. ``initialize_write()`` may return a result
  (let's call this ``init_result``) that contains any parameters it wants to
  pass on to its writers about the sink. For example, a sink that writes to a
  file system may return an ``init_result`` that contains a dynamically
  generated unique directory to which data should be written.

  To perform writing of a bundle of elements, Dataflow execution engine will
  create an ``iobase.Writer`` using the implementation of
  ``iobase.Sink.open_writer()``. When invoking ``open_writer()`` execution
  engine will provide the ``init_result`` returned by ``initialize_write()``
  invocation as well as a *bundle id* (let's call this ``bundle_id``) that is
  unique for each invocation of ``open_writer()``.

  Execution engine will then invoke ``iobase.Writer.write()`` implementation for
  each element that has to be written. Once all elements of a bundle are
  written, execution engine will invoke ``iobase.Writer.close()`` implementation
  which should return a result (let's call this ``write_result``) that contains
  information that encodes the result of the write and, in most cases, some
  encoding of the unique bundle id. For example, if each bundle is written to a
  unique temporary file, ``close()`` method may return an object that contains
  the temporary file name. After writing of all bundles is complete, execution
  engine will invoke ``pre_finalize()`` and then ``finalize_write()``
  implementation.

  The execution of a write transform can be illustrated using following pseudo
  code (assume that the outer for loop happens in parallel across many
  machines)::

    init_result = sink.initialize_write()
    write_results = []
    for bundle in partition(pcoll):
      writer = sink.open_writer(init_result, generate_bundle_id())
      for elem in bundle:
        writer.write(elem)
      write_results.append(writer.close())
    pre_finalize_result = sink.pre_finalize(init_result, write_results)
    sink.finalize_write(init_result, write_results, pre_finalize_result)


  **init_result**

  Methods of 'iobase.Sink' should agree on the 'init_result' type that will be
  returned when initializing the sink. This type can be a client-defined object
  or an existing type. The returned type must be picklable using Dataflow coder
  ``coders.PickleCoder``. Returning an init_result is optional.

  **bundle_id**

  In order to ensure fault-tolerance, a bundle may be executed multiple times
  (e.g., in the event of failure/retry or for redundancy). However, exactly one
  of these executions will have its result passed to the
  ``iobase.Sink.finalize_write()`` method. Each call to
  ``iobase.Sink.open_writer()`` is passed a unique bundle id when it is called
  by the ``WriteImpl`` transform, so even redundant or retried bundles will have
  a unique way of identifying their output.

  The bundle id should be used to guarantee that a bundle's output is unique.
  This uniqueness guarantee is important; if a bundle is to be output to a file,
  for example, the name of the file must be unique to avoid conflicts with other
  writers. The bundle id should be encoded in the writer result returned by the
  writer and subsequently used by the ``finalize_write()`` method to identify
  the results of successful writes.

  For example, consider the scenario where a Writer writes files containing
  serialized records and the ``finalize_write()`` is to merge or rename these
  output files. In this case, a writer may use its unique id to name its output
  file (to avoid conflicts) and return the name of the file it wrote as its
  writer result. The ``finalize_write()`` will then receive an ``Iterable`` of
  output file names that it can then merge or rename using some bundle naming
  scheme.

  **write_result**

  ``iobase.Writer.close()`` and ``finalize_write()`` implementations must agree
  on the type of the ``write_result`` object returned when invoking
  ``iobase.Writer.close()``. This type can be a client-defined object or
  an existing type. The returned type must be picklable with
  ``coders.PickleCoder``. Returning a ``write_result`` when
  ``iobase.Writer.close()`` is invoked is optional, but if unique
  ``write_result`` objects are not returned, the sink should guarantee
  idempotency when the same bundle is written multiple times due to
  failure/retry or redundancy.


  **More information**

  For more information on creating new sinks please refer to the official
  documentation at
  ``https://beam.apache.org/documentation/sdks/python-custom-io#creating-sinks``
  """

  def initialize_write(self):
    """Initializes the sink before writing begins.

    Called once, before any data is written to the sink. See the
    documentation of ``iobase.Sink`` for an example.

    Returns:
      An object holding any sink-specific state produced during
      initialization. It is handed to ``open_writer()``, ``pre_finalize()``
      and ``finalize_write()``.
    """
    raise NotImplementedError

  def open_writer(self, init_result, uid):
    """Opens a writer that writes one bundle of elements to this sink.

    Args:
      init_result: the value returned by ``initialize_write()``.
      uid: a unique identifier generated by the system for this writer.

    Returns:
      An ``iobase.Writer`` to which a bundle of records for the current sink
      can be written.
    """
    raise NotImplementedError

  def pre_finalize(self, init_result, writer_results):
    """Runs the pre-finalization stage of the sink.

    Invoked once all bundle writes have completed and before
    ``finalize_write()``. Intended for setting up and verifying the state of
    the filesystem and of the sink.

    Args:
      init_result: the value returned by ``initialize_write()``.
      writer_results: an iterable of the values returned by
        ``Writer.close()``. Only successful writes are included, and each
        bundle contributes the result of exactly one successful write.

    Returns:
      An object holding any sink-specific state produced at this stage; it is
      passed on to ``finalize_write()``.
    """
    raise NotImplementedError

  def finalize_write(self, init_result, writer_results,
                     pre_finalize_result):
    """Finalizes the sink once all data has been written to it.

    Using the initialization result and the results of the bundle writes,
    performs any post-write finalization and closes the sink. Invoked after
    every bundle write has completed.

    Only the results of bundles that completed successfully are passed in.
    A bundle may have been executed more than once (for fault tolerance), but
    exactly one writer result is supplied per bundle. Implementations should
    therefore clean up after failed bundles and after bundles that were
    retried successfully. Because failed bundles never contribute a writer
    result, finalization must be able to locate any temporary or partial
    output they left behind on its own.

    When every retry of some bundle fails, the whole pipeline fails and
    finalize_write() is *not* invoked.

    Ideally finalization is atomic. Where the semantics of the sink make that
    impossible, it should at least be idempotent, since it may be invoked
    several times due to failure/retry or redundancy.

    The order in which the writer results are iterated is not guaranteed to
    be the same across multiple invocations.

    Args:
      init_result: the value returned by ``initialize_write()``.
      writer_results: an iterable of the values returned by
        ``Writer.close()``. Only successful writes are included, and each
        bundle contributes the result of exactly one successful write.
      pre_finalize_result: the value returned by ``pre_finalize()``.
    """
    raise NotImplementedError


class Writer(object):
  """This class is deprecated, no backwards-compatibility guarantees.

  Writes a bundle of elements from a ``PCollection`` to a sink.

  ``iobase.Writer.write()`` writes individual elements to the sink, and
  ``iobase.Writer.close()`` is called once every element of the bundle has
  been written.

  The ``iobase.Sink`` documentation describes the full process of writing to
  a sink in more detail.
  """

  def write(self, value):
    """Writes ``value`` to the sink through this writer."""
    raise NotImplementedError

  def close(self):
    """Closes this writer after its bundle has been written.

    See the documentation of ``iobase.Sink`` for an example.

    Returns:
      An object describing the writes performed by this writer.
    """
    raise NotImplementedError


class Read(ptransform.PTransform):
  """A transform that reads a PCollection.

  When the ``beam_fn_api`` experiment is enabled, the read is expanded into
  Impulse -> split the source -> Reshuffle -> read each split. Otherwise Read
  is treated as a primitive and left to the runner.
  """

  def __init__(self, source):
    """Initializes a Read transform.

    Args:
      source: Data source to read from.
    """
    super(Read, self).__init__()
    self.source = source

  @staticmethod
  def get_desired_chunk_size(total_size):
    """Returns the desired size of each split of a source.

    The chunk size grows with the square root of the total size and is never
    smaller than 1MB (``1 << 20``).

    Args:
      total_size: the estimated total size of the source, or a falsy value
        (e.g. ``None`` or ``0``) when the size is unknown.

    Returns:
      The desired chunk size; 64MB (``64 << 20``) when ``total_size`` is
      falsy.
    """
    if total_size:
      # 1MB = 1 shard, 1GB = 32 shards, 1TB = 1000 shards, 1PB = 32k shards
      return max(1 << 20, 1000 * int(math.sqrt(total_size)))
    return 64 << 20  # 64mb

  def expand(self, pbegin):
    from apache_beam.options.pipeline_options import DebugOptions
    from apache_beam.transforms import util

    assert isinstance(pbegin, pvalue.PBegin)
    self.pipeline = pbegin.pipeline

    debug_options = self.pipeline._options.view_as(DebugOptions)
    if debug_options.experiments and 'beam_fn_api' in debug_options.experiments:
      # Bind only the source for the closure below. Referencing ``self`` there
      # would pull this transform (including ``self.pipeline``) into the
      # serialized FlatMap function.
      source = self.source

      def split_source(unused_impulse):
        return source.split(
            Read.get_desired_chunk_size(source.estimate_size()))

      return (
          pbegin
          | core.Impulse()
          | 'Split' >> core.FlatMap(split_source)
          | util.Reshuffle()
          | 'ReadSplits' >> core.FlatMap(lambda split: split.source.read(
              split.source.get_range_tracker(
                  split.start_position, split.stop_position))))
    else:
      # Treat Read itself as a primitive.
      return pvalue.PCollection(self.pipeline)

  def get_windowing(self, unused_inputs):
    return core.Windowing(window.GlobalWindows())

  def _infer_output_coder(self, input_type=None, input_coder=None):
    if isinstance(self.source, BoundedSource):
      return self.source.default_output_coder()
    else:
      return self.source.coder

  def display_data(self):
    return {'source': DisplayDataItem(self.source.__class__,
                                      label='Read Source'),
            'source_dd': self.source}

  def to_runner_api_parameter(self, context):
    return (common_urns.deprecated_primitives.READ.urn,
            beam_runner_api_pb2.ReadPayload(
                source=self.source.to_runner_api(context),
                is_bounded=beam_runner_api_pb2.IsBounded.BOUNDED
                if self.source.is_bounded()
                else beam_runner_api_pb2.IsBounded.UNBOUNDED))

  @staticmethod
  def from_runner_api_parameter(parameter, context):
    return Read(SourceBase.from_runner_api(parameter.source, context))


# Register the deprecated READ primitive URN so that a runner API transform
# carrying a ``ReadPayload`` is turned back into a ``Read`` transform via
# ``Read.from_runner_api_parameter``.
ptransform.PTransform.register_urn(
    common_urns.deprecated_primitives.READ.urn,
    beam_runner_api_pb2.ReadPayload,
    Read.from_runner_api_parameter)


class Write(ptransform.PTransform):
  """A ``PTransform`` that writes to a sink.

  A sink should inherit ``iobase.Sink``. Such implementations are handled by
  ``WriteImpl``, a composite transform that performs:

  (1) a global initialization (``Sink.initialize_write()``),
  (2) a parallel write of the input through ``iobase.Writer`` objects,
  (3) a global pre-finalization (``Sink.pre_finalize()``), and
  (4) a global finalization (``Sink.finalize_write()``).

  In the case of an empty ``PCollection``, only the global steps are
  performed. Currently only batch workflows support custom sinks.

  Example usage::

      pcollection | beam.io.Write(MySink())

  This returns a ``pvalue.PValue`` object that represents the end of the
  Pipeline.

  The sink argument may also be a full PTransform, in which case it will be
  applied directly.  This allows composite sink-like transforms (e.g. a sink
  with some pre-processing DoFns) to be used the same as all other sinks.

  This transform also supports sinks that inherit ``iobase.NativeSink``. These
  are sinks that are implemented natively by the Dataflow service and hence
  should not be updated by users. These sinks are processed using a Dataflow
  native write transform.
  """

  def __init__(self, sink):
    """Initializes a Write transform.

    Args:
      sink: Data sink to write to.
    """
    super(Write, self).__init__()
    self.sink = sink

  def display_data(self):
    return {'sink': self.sink.__class__,
            'sink_dd': self.sink}

  def expand(self, pcoll):
    from apache_beam.runners.dataflow.native_io import iobase as dataflow_io
    if isinstance(self.sink, dataflow_io.NativeSink):
      # A native sink
      return pcoll | 'NativeWrite' >> dataflow_io._NativeWrite(self.sink)
    elif isinstance(self.sink, Sink):
      # A custom sink
      return pcoll | WriteImpl(self.sink)
    elif isinstance(self.sink, ptransform.PTransform):
      # This allows "composite" sinks to be used like non-composite ones.
      return pcoll | self.sink
    else:
      # Wrap the sink in a 1-tuple so that a tuple-valued sink is formatted
      # with %r instead of being unpacked as the format arguments.
      raise ValueError('A sink must inherit iobase.Sink, iobase.NativeSink, '
                       'or be a PTransform. Received : %r' % (self.sink,))


class WriteImpl(ptransform.PTransform):
  """Implements the writing of custom sinks.

  The expansion calls ``initialize_write()`` on a single element, writes the
  input through writers opened on ``self.sink``, and then runs
  ``pre_finalize()`` and ``finalize_write()`` over the collected writer
  results, again driven by that single element.
  """

  def __init__(self, sink):
    super(WriteImpl, self).__init__()
    self.sink = sink

  def expand(self, pcoll):
    """Builds the initialize / write / pre-finalize / finalize graph.

    Args:
      pcoll: the ``PCollection`` of elements to write to ``self.sink``.

    Returns:
      The ``PCollection`` produced by the 'FinalizeWrite' step.
    """
    # Single-element PCollection that drives the global (non-parallel) steps.
    do_once = pcoll.pipeline | 'DoOnce' >> core.Create([None])
    init_result_coll = do_once | 'InitializeWrite' >> core.Map(
        lambda _, sink: sink.initialize_write(), self.sink)
    if getattr(self.sink, 'num_shards', 0):
      # Fixed sharding: key every element by a shard id, group, and write each
      # group with its own writer.
      min_shards = self.sink.num_shards
      if min_shards == 1:
        keyed_pcoll = pcoll | core.Map(lambda x: (None, x))
      else:
        keyed_pcoll = pcoll | core.ParDo(_RoundRobinKeyFn(min_shards))
      write_result_coll = (keyed_pcoll
                           | core.WindowInto(window.GlobalWindows())
                           | core.GroupByKey()
                           | 'WriteBundles' >> core.ParDo(
                               _WriteKeyedBundleDoFn(self.sink),
                               AsSingleton(init_result_coll)))
    else:
      # Runner-determined sharding: one writer per bundle. The writer results
      # are then gathered under a single key and flattened back out.
      min_shards = 1
      write_result_coll = (pcoll
                           | 'WriteBundles' >>
                           core.ParDo(_WriteBundleDoFn(self.sink),
                                      AsSingleton(init_result_coll))
                           | 'Pair' >> core.Map(lambda x: (None, x))
                           | core.WindowInto(window.GlobalWindows())
                           | core.GroupByKey()
                           | 'Extract' >> core.FlatMap(lambda x: x[1]))
    # PreFinalize should run before FinalizeWrite, and the two should not be
    # fused.
    # FinalizeWrite reads PreFinalize's output as an AsSingleton side input,
    # so that output must be available before FinalizeWrite can run.
    pre_finalize_coll = do_once | 'PreFinalize' >> core.FlatMap(
        _pre_finalize,
        self.sink,
        AsSingleton(init_result_coll),
        AsIter(write_result_coll))
    # min_shards lets _finalize_write add empty shards when fewer writer
    # results than min_shards were produced.
    return do_once | 'FinalizeWrite' >> core.FlatMap(
        _finalize_write,
        self.sink,
        AsSingleton(init_result_coll),
        AsIter(write_result_coll),
        min_shards,
        AsSingleton(pre_finalize_coll))


class _WriteBundleDoFn(core.DoFn):
  """A DoFn that writes all elements of a bundle through one iobase.Writer.

  The writer is opened lazily when the bundle's first element arrives and is
  closed in finish_bundle(). The result of ``close()`` is emitted in the
  global window at its maximum timestamp.
  """

  def __init__(self, sink):
    self.sink = sink

  def display_data(self):
    return {'sink_dd': self.sink}

  def start_bundle(self):
    # No writer exists until the bundle actually contains an element.
    self.writer = None

  def process(self, element, init_result):
    current = self.writer
    if current is None:
      # We ignore UUID collisions here since they are extremely rare.
      current = self.sink.open_writer(init_result, str(uuid.uuid4()))
      self.writer = current
    current.write(element)

  def finish_bundle(self):
    if self.writer is None:
      # Empty bundle: no writer was opened, so there is nothing to emit.
      return
    close_result = self.writer.close()
    global_window = window.GlobalWindow()
    yield WindowedValue(close_result,
                        global_window.max_timestamp(),
                        [global_window])


class _WriteKeyedBundleDoFn(core.DoFn):
  """A DoFn that writes each grouped ``(key, values)`` element as one shard.

  A fresh writer is opened for every element. All grouped values are written
  through it, and the result of ``close()`` is emitted with the maximum
  timestamp.
  """

  def __init__(self, sink):
    self.sink = sink

  def display_data(self):
    return {'sink_dd': self.sink}

  def process(self, element, init_result):
    shard_writer = self.sink.open_writer(init_result, str(uuid.uuid4()))
    # element[0] is the shard key, which is not needed for writing.
    grouped_values = element[1]
    for value in grouped_values:
      shard_writer.write(value)
    close_result = shard_writer.close()
    return [window.TimestampedValue(close_result, timestamp.MAX_TIMESTAMP)]


def _pre_finalize(unused_element, sink, init_result, write_results):
  return sink.pre_finalize(init_result, write_results)


def _finalize_write(unused_element, sink, init_result, write_results,
                    min_shards, pre_finalize_results):
  write_results = list(write_results)
  extra_shards = []
  if len(write_results) < min_shards:
    logging.debug(
        'Creating %s empty shard(s).', min_shards - len(write_results))
    for _ in range(min_shards - len(write_results)):
      writer = sink.open_writer(init_result, str(uuid.uuid4()))
      extra_shards.append(writer.close())
  outputs = sink.finalize_write(init_result, write_results + extra_shards,
                                pre_finalize_results)
  if outputs:
    return (
        window.TimestampedValue(v, timestamp.MAX_TIMESTAMP) for v in outputs)


class _RoundRobinKeyFn(core.DoFn):
  """Keys elements with shard ids in ``[0, count)`` in round-robin order.

  Each bundle starts from a random offset. Presumably this spreads small
  bundles across shards rather than always starting at the same id.
  """

  def __init__(self, count):
    self.count = count

  def start_bundle(self):
    # Random starting position for this bundle.
    self.counter = random.randint(0, self.count - 1)

  def process(self, element):
    shard = self.counter + 1
    if shard >= self.count:
      # Wrap around to the beginning of the shard range.
      shard -= self.count
    self.counter = shard
    yield shard, element


class RestrictionTracker(object):
  """Manages concurrent access to a restriction.

  Experimental; no backwards-compatibility guarantees.

  Tracks which part of a restriction has been claimed by a Splittable DoFn.

  For more details, see the following documents:
  * https://s.apache.org/splittable-do-fn
  * https://s.apache.org/splittable-do-fn-python-sdk
  """

  def current_restriction(self):
    """Returns the current restriction.

    The returned restriction accurately describes the full range of work the
    current ``DoFn.process()`` call will perform, including work that has
    already been completed.

    The value returned by this method may change dynamically as a result of
    concurrent invocations of other ``RestrictionTracker`` methods, for
    example ``checkpoint()``.

    ** Thread safety **

    Methods of ``RestrictionTracker``, including this one, may be invoked
    from different threads and must therefore be thread-safe, e.g. by
    guarding them all with a single lock object.

    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
    """
    raise NotImplementedError

  def current_progress(self):
    """Returns a ``RestrictionProgress`` describing the current progress."""
    raise NotImplementedError

  def current_watermark(self):
    """Returns the current watermark.

    By default no watermark is reported and ``None`` is returned.

    TODO(BEAM-7473): Provide synchronization guarantee by using a wrapper.
    """
    return None

  def checkpoint(self):
    """Performs a checkpoint of the current restriction.

    Signals that the current ``DoFn.process()`` call should terminate as soon
    as possible. Once this method returns, the tracker MUST refuse every
    further claim, and ``RestrictionTracker.check_done()`` MUST succeed.

    The call updates the value returned by ``current_restriction()`` and
    returns a restriction representing the rest of the work. Together, the
    new ``current_restriction()`` and this method's return value are
    equivalent to the old ``current_restriction()``.

    ** Thread safety **

    Methods of ``RestrictionTracker``, including this one, may be invoked
    from different threads and must therefore be thread-safe, e.g. by
    guarding them all with a single lock object.

    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
    """
    raise NotImplementedError

  def check_done(self):
    """Checks whether the restriction has been fully processed.

    Invoked by the runner once the iterator returned by ``DoFn.process()``
    has been fully consumed.

    Must raise a `ValueError` with an informative message when any unclaimed
    work remains in the restriction at the time of the call.

    ** Thread safety **

    Methods of ``RestrictionTracker``, including this one, may be invoked
    from different threads and must therefore be thread-safe, e.g. by
    guarding them all with a single lock object.

    TODO(BEAM-7473): Remove thread safety requirements from API implementation.

    Returns: ``True`` if current restriction has been fully processed.
    Raises:
      ~exceptions.ValueError: if there is still any unclaimed work remaining.
    """
    raise NotImplementedError

  def try_split(self, fraction_of_remainder):
    """Splits the current restriction based on ``fraction_of_remainder``.

    When splitting is possible, the current restriction is divided into a
    primary and a residual restriction. The call updates
    ``current_restriction()`` to the primary restriction, so the current
    ``DoFn.process()`` execution becomes responsible for exactly the work the
    primary represents. The residual is executed by a separate
    ``DoFn.process()`` invocation, likely in a different process. Executing
    the primary and the residual separately MUST perform the same work as if
    the split had never happened.

    ``fraction_of_remainder`` is a best-effort hint for choosing the primary
    and residual based on the fraction of the remaining work that the current
    ``DoFn.process()`` invocation should keep. For example, suppose a
    ``DoFn.process()`` reads a file with the offset range [100, 200), has
    processed up to offset 130, and receives a fraction_of_remainder of 0.7.
    The primary and residual restrictions returned would then be [100, 179)
    and [179, 200) (note: current_offset + fraction_of_remainder *
    remaining_work = 130 + 0.7 * 70 = 179).

    A good try_split implementation is very important for pipeline scaling
    and end-to-end pipeline execution.

    Args:
      fraction_of_remainder: A hint as to the fraction of work the primary
        restriction should represent based upon the current known remaining
        amount of work.

    Returns:
      (primary_restriction, residual_restriction) if a split was possible,
      otherwise returns ``None``.

    ** Thread safety **

    Methods of ``RestrictionTracker``, including this one, may be invoked
    from different threads and must therefore be thread-safe, e.g. by
    guarding them all with a single lock object.

    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
    """
    raise NotImplementedError

  def try_claim(self, position):
    """Attempts to claim the block of work identified by ``position``.

    On success, the DoFn MUST execute the entire block of work. On failure,
    ``DoFn.process()`` MUST return ``None`` without performing any further
    work or emitting output. Emitting output or performing work from
    ``DoFn.process()`` before the first call of this method is not allowed
    either.

    Args:
      position: current position that wants to be claimed.

    Returns: ``True`` if the position can be claimed as current_position.
    Otherwise, returns ``False``.

    ** Thread safety **

    Methods of ``RestrictionTracker``, including this one, may be invoked
    from different threads and must therefore be thread-safe, e.g. by
    guarding them all with a single lock object.

    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
    """
    raise NotImplementedError

  def defer_remainder(self, watermark=None):
    """Invokes checkpoint() from within an SDF.process().

    TODO(BEAM-7472): Remove defer_remainder() once SDF.process() uses
    ``ProcessContinuation``.

    Args:
      watermark: optional watermark to associate with the deferred remainder
        (presumably what ``deferred_status()`` later reports as the deferred
        watermark -- confirm against implementations).
    """
    raise NotImplementedError

  def deferred_status(self):
    """Returns the deferred residual together with the deferred watermark.

    TODO(BEAM-7472): Remove deferred_status() once SDF.process() uses
    ``ProcessContinuation``.
    """
    raise NotImplementedError


class RestrictionProgress(object):
  """Used to record the progress of a restriction.

  Experimental; no backwards-compatibility guarantees.

  Progress is described by any combination of the keyword arguments
  ``fraction`` (the fraction of work completed; ``fraction_remaining`` is
  ``1 - fraction``), ``completed`` and ``remaining`` (amounts of work in a
  caller-chosen unit). Missing quantities are derived from the given ones
  where possible. When only ``fraction`` is known, derived amounts are
  expressed in normalized units where the total work is 1.
  """
  def __init__(self, **kwargs):
    # Only accept keyword arguments.
    self._fraction = kwargs.pop('fraction', None)
    self._completed = kwargs.pop('completed', None)
    self._remaining = kwargs.pop('remaining', None)
    assert not kwargs

  def __repr__(self):
    return 'RestrictionProgress(fraction=%s, completed=%s, remaining=%s)' % (
        self._fraction, self._completed, self._remaining)

  @property
  def completed_work(self):
    """Amount of completed work, or ``None`` if it cannot be determined."""
    # Compare with None explicitly: 0 is a legitimate amount of work and must
    # not be treated as "unknown".
    if self._completed is not None:
      return self._completed
    if self._fraction is None:
      return None
    if self._remaining is not None and self._fraction != 1:
      # fraction = completed / (completed + remaining), solved for completed.
      return self._remaining * self._fraction / (1 - self._fraction)
    # Only the fraction is usable: normalized units (total work == 1).
    return self._fraction

  @property
  def remaining_work(self):
    """Amount of remaining work, or ``None`` if it cannot be determined."""
    if self._remaining is not None:
      return self._remaining
    if self._fraction is None:
      return None
    if self._completed is not None and self._fraction != 0:
      # fraction = completed / (completed + remaining), solved for remaining.
      return self._completed * (1 - self._fraction) / self._fraction
    # Only the fraction is usable: normalized units (total work == 1).
    return 1 - self._fraction

  @property
  def total_work(self):
    return self.completed_work + self.remaining_work

  @property
  def fraction_completed(self):
    if self._fraction is not None:
      return self._fraction
    else:
      return float(self._completed) / self.total_work

  @property
  def fraction_remaining(self):
    if self._fraction is not None:
      return 1 - self._fraction
    else:
      return float(self._remaining) / self.total_work

  def with_completed(self, completed):
    """Returns a copy of this progress with ``completed`` replaced."""
    return RestrictionProgress(
        fraction=self._fraction, remaining=self._remaining, completed=completed)


class _SDFBoundedSourceWrapper(ptransform.PTransform):
  """A ``PTransform`` that uses SDF to read from a ``BoundedSource``.

  NOTE: This transform can only be used with beam_fn_api enabled.
  """
  class _SDFBoundedSourceRestrictionTracker(RestrictionTracker):
    """An `iobase.RestrictionTracker` implementation for wrapping
    BoundedSource with SDF.

    Delegated RangeTracker guarantees synchronization safety.
    """
    def __init__(self, restriction):
      if not isinstance(restriction, SourceBundle):
        raise ValueError('Initializing SDFBoundedSourceRestrictionTracker '
                         'requires a SourceBundle')
      self._delegate_range_tracker = restriction.source.get_range_tracker(
          restriction.start_position, restriction.stop_position)
      self._source = restriction.source
      self._weight = restriction.weight

    def current_progress(self):
      return RestrictionProgress(
          fraction=self._delegate_range_tracker.fraction_consumed())

    def current_restriction(self):
      start_pos = self._delegate_range_tracker.start_position()
      stop_pos = self._delegate_range_tracker.stop_position()
      return SourceBundle(
          self._weight,
          self._source,
          start_pos,
          stop_pos)

    def start_pos(self):
      return self._delegate_range_tracker.start_position()

    def stop_pos(self):
      return self._delegate_range_tracker.stop_position()

    def try_claim(self, position):
      return self._delegate_range_tracker.try_claim(position)

    def try_split(self, fraction_of_remainder):
      """Splits the remaining range at ``fraction_of_remainder``.

      Returns:
        A ``(primary, residual)`` pair of ``SourceBundle`` objects whose
        weights split ``self._weight`` by the split fraction reported by the
        range tracker, or ``None`` if the range tracker refused the split.
      """
      consumed_fraction = self._delegate_range_tracker.fraction_consumed()
      fraction = (consumed_fraction +
                  (1 - consumed_fraction) * fraction_of_remainder)
      position = self._delegate_range_tracker.position_at_fraction(fraction)
      # Need to stash current stop_pos before splitting since
      # range_tracker.split will update its stop_pos if splits
      # successfully.
      start_pos = self.start_pos()
      stop_pos = self.stop_pos()
      split_result = self._delegate_range_tracker.try_split(position)
      if split_result:
        split_pos, split_fraction = split_result
        primary_weight = self._weight * split_fraction
        residual_weight = self._weight - primary_weight
        # Update self._weight to primary weight
        self._weight = primary_weight
        return (SourceBundle(primary_weight, self._source, start_pos,
                             split_pos),
                SourceBundle(residual_weight, self._source, split_pos,
                             stop_pos))
      return None

    def deferred_status(self):
      return None

    def current_watermark(self):
      return None

    def get_delegate_range_tracker(self):
      return self._delegate_range_tracker

    def get_tracking_source(self):
      return self._source

  class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
    """A `RestrictionProvider` that is used by SDF for `BoundedSource`."""

    def __init__(self, source, desired_chunk_size=None):
      self._source = source
      self._desired_chunk_size = desired_chunk_size

    def initial_restriction(self, element):
      # Get initial range_tracker from source
      range_tracker = self._source.get_range_tracker(None, None)
      return SourceBundle(None,
                          self._source,
                          range_tracker.start_position(),
                          range_tracker.stop_position())

    def create_tracker(self, restriction):
      return _SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionTracker(
          restriction)

    def split(self, element, restriction):
      # Invoke source.split to get initial splitting results. Note that the
      # whole source is split; ``restriction`` itself is not consulted.
      source_bundles = self._source.split(self._desired_chunk_size)
      for source_bundle in source_bundles:
        yield source_bundle

    def restriction_size(self, element, restriction):
      return restriction.weight

  def __init__(self, source):
    if not isinstance(source, BoundedSource):
      raise RuntimeError('SDFBoundedSourceWrapper can only wrap BoundedSource')
    super(_SDFBoundedSourceWrapper, self).__init__()
    self.source = source

  def _create_sdf_bounded_source_dofn(self):
    source = self.source
    chunk_size = Read.get_desired_chunk_size(source.estimate_size())

    class SDFBoundedSourceDoFn(core.DoFn):
      def __init__(self, read_source):
        self.source = read_source

      def process(
          self,
          element,
          restriction_tracker=core.DoFn.RestrictionParam(
              _SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionProvider(
                  source, chunk_size))):
        return restriction_tracker.get_tracking_source().read(
            restriction_tracker.get_delegate_range_tracker())

    return SDFBoundedSourceDoFn(self.source)

  def expand(self, pbegin):
    return (pbegin
            | core.Impulse()
            | core.ParDo(self._create_sdf_bounded_source_dofn()))

  def get_windowing(self, unused_inputs):
    return core.Windowing(window.GlobalWindows())

  def _infer_output_coder(self, input_type=None, input_coder=None):
    return self.source.default_output_coder()

  def display_data(self):
    return {'source': DisplayDataItem(self.source.__class__,
                                      label='Read Source'),
            'source_dd': self.source}
